from abc import ABCMeta, abstractmethod
import itertools, subprocess, time, csv
from .common import cgroupHelper


class BaseCollect(metaclass=ABCMeta):
    """Abstract base for collectors that append timestamped rows to a CSV file.

    Subclasses must implement ``collect``. The metaclass is declared with
    Python 3 syntax: a class-level ``__metaclass__`` attribute is ignored by
    Python 3, which would leave ``@abstractmethod`` unenforced.

    Args:
        interval: stored as ``self.interval``; not used by this class itself
            (presumably the sampling interval -- confirm with subclasses).
        timeout: stored as ``self.timeout`` (default 5); not used by this
            class itself.
    """

    def __init__(self, interval, timeout=5):
        self.interval = interval
        self.timeout = timeout

    @abstractmethod
    def collect(self):
        """Collect one sample; presumably the values later passed to ``write``."""
        pass

    def write(self, data, outputFile):
        """Append ``data`` as one CSV row, prefixed with the local timestamp.

        Args:
            data: iterable of cell values appended after the timestamp.
            outputFile: path of the CSV file; created if it does not exist.
        """
        row = [time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())]
        row.extend(data)

        # newline="" is required by the csv module: otherwise the writer's
        # "\r\n" terminator is translated again on Windows (blank rows).
        with open(outputFile, "a+", newline="") as csvfile:
            csvWriter = csv.writer(csvfile, delimiter=",")
            csvWriter.writerow(row)


class PressureCollect(object):
    """PressureCollect collects cgroup v2 PSI (pressure stall information).

    Files read from the cgroup directory and the fields reported from each:
    memory.pressure:
        some avg10 avg60 avg300 total
        full avg10 avg60 avg300 total
    cpu.pressure:
        some avg10 avg60 avg300 total
    io.pressure:
        some avg10 avg60 avg300 total
        full avg10 avg60 avg300 total

    Attributes:
        v: cgroup version. "v1" has no PSI, so every column is reported as
            -1; any other truthy value resolves the directory through
            ``cgroupHelper.get_cpu_path``; a falsy value reads ``cgroupPath``
            directly.
        driver, ctrID: forwarded to ``cgroupHelper.get_cpu_path``.
        cgroupPath: cgroup directory (used as the prefix when ``v`` is set).
            Default: /sys/fs/cgroup/system.slice
            1) to collect PSI of all containers, specify the cgroupPath of
               docker or containerd (output: # of containers).
            2) to collect PSI of a process, specify the cgroupPath of the
               process.
    """

    def __init__(
        self, v=None, driver=None, ctrID=None, cgroupPath="/sys/fs/cgroup/system.slice"
    ):
        self.v = v
        self.driver = driver
        self.ctrID = ctrID
        self.cgroupPath = cgroupPath
        self.pressures = ["memory.pressure", "io.pressure", "cpu.pressure"]
        self.columns = PressureCollect.get_columns()

    def collect(self):
        """Return one PSI sample aligned with ``self.columns``.

        Values are the raw strings read from the pressure files; a column
        whose field is absent is reported as -1. Under cgroup v1 every column
        is -1.

        Raises:
            OSError: if a pressure file cannot be opened.
        """
        if self.v == "v1":
            return [-1] * len(self.columns)

        if self.v:
            path = cgroupHelper.get_cpu_path(
                v=self.v,
                driver=self.driver,
                ctrID=self.ctrID,
                cgroupPrefix=self.cgroupPath,
            )
        else:
            path = self.cgroupPath

        # Key every value by its column name ("<resource>.<kind>.<metric>").
        values = {}
        for pres in self.pressures:
            resource = pres.split(".", 1)[0]
            parsed = self._read_pressure(path + "/" + pres)
            for kind, metrics in parsed.items():
                for metric, value in metrics.items():
                    values[".".join((resource, kind, metric))] = value

        # Emit exactly one value per column, in column order. Kernels >= 5.13
        # also report a "full" line in a cgroup's cpu.pressure; blindly
        # concatenating every line would misalign the row with the columns.
        return [values.get(col, -1) for col in self.columns]

    @staticmethod
    def _read_pressure(filepath):
        """Parse a PSI file into ``{kind: {metric: value}}``.

        e.g. "some avg10=0.00 total=5" -> {"some": {"avg10": "0.00", "total": "5"}}
        """
        result = {}
        with open(filepath, mode="r", encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                result[fields[0]] = dict(field.split("=", 1) for field in fields[1:])
        return result

    def pressure_parse(self, file):
        """Return every ``key=value`` value of a PSI file, in file order.

        The leading "some"/"full" token of each line is dropped; values are
        returned as strings.
        """
        psi = []
        with open(file, mode="r", encoding="utf-8") as f:
            for line in f:
                fields = line.rstrip().split()
                psi.extend([field.split("=", 1)[1] for field in fields[1:]])
        return psi

    @classmethod
    def get_columns(cls):
        """Column names for pressures, e.g. "memory.some.avg10"."""
        prefixes = [
            ["memory", "io", "cpu"],
            ["some", "full"],
            ["avg10", "avg60", "avg300", "total"],
        ]
        cols = [".".join(ele) for ele in itertools.product(*prefixes)]
        # cpu.full.* is not collected; it is the last group of 4 in product order.
        return cols[:-4]


class RdtCollect(object):
    """Collect CPI and RDT-style metrics by running the ``pqos`` tool.

    The monitored pids come either from the cgroup resolved through
    ``cgroupHelper.get_cpu_path`` (``tasks`` for v1, ``cgroup.procs``
    otherwise) when ``v`` is set, or from ``pid`` when it is not.
    On any failure a row of -1 values is returned.
    """

    def __init__(
        self, pid=None, v=None, driver=None, ctrID=None, cgroupPath="/sys/fs/cgroup"
    ):
        self.v, self.ctrID, self.driver, self.cgroupPath = v, ctrID, driver, cgroupPath
        self.pid = pid
        self.columns = RdtCollect.get_columns()

    @classmethod
    def get_columns(cls):
        return ["CPI", "MKPI", "LLC", "LMB", "RMB"]

    @staticmethod
    def _log(message):
        """Print ``message`` prefixed with the local timestamp."""
        print(
            "%s : %s"
            % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), message)
        )

    def _target_pids(self):
        """Return the pids to monitor as a list of non-empty strings."""
        if not self.v:
            return [] if self.pid is None else [str(self.pid)]

        path = cgroupHelper.get_cpu_path(
            v=self.v,
            driver=self.driver,
            ctrID=self.ctrID,
            cgroupPrefix=self.cgroupPath,
        )
        taskFile = "tasks" if self.v == "v1" else "cgroup.procs"
        with open("/".join([path, taskFile]), mode="r") as f:
            # One pid per line; skip blanks so no trailing comma reaches pqos.
            return [line.strip() for line in f if line.strip()]

    @staticmethod
    def _parse_pqos_output(output):
        """Turn pqos text output into ``[cpi, mkpi, llc, lmb, rmb]``.

        CPI is 1/IPC as a float (-1 when IPC is 0); the other values are
        returned as the raw strings pqos printed.

        Raises:
            IndexError, ValueError: if the output does not have the expected shape.
        """
        # NOTE(review): assumes the second output line is a data row whose
        # whitespace-separated fields 2..6 are IPC, misses, LLC, local and
        # remote memory bandwidth -- verify against the installed pqos.
        fields = output.rstrip().split("\n")[1].split()
        ipc, mkpi, llc, lmb, rmb = fields[2:7]
        ipcValue = float(ipc)  # IPC is a ratio, usually fractional
        cpi = -1 if ipcValue == 0 else 1 / ipcValue
        # TODO: lmx ?
        return [cpi, mkpi, llc, lmb, rmb]

    def collect(self):
        """Run pqos over the target pids and return one row for ``self.columns``."""
        pids = self._target_pids()
        if not pids:
            self._log("no pid to monitor")
            return [-1] * len(self.columns)

        # TODO: wrap with thread?
        # argv list, no shell: pids come from a file and must not be
        # interpreted by a shell; each flag is its own argument.
        cmd = ["pqos", "-p", "all:[%s]" % ",".join(pids), "-I", "-i", "10", "-t", "2"]
        try:
            output = subprocess.check_output(
                cmd, stderr=subprocess.STDOUT, universal_newlines=True
            )
        except subprocess.CalledProcessError as e:
            self._log("execute %s failed, %s" % (" ".join(cmd), e.output))
            return [-1] * len(self.columns)
        except OSError as e:
            # e.g. pqos is not installed (no shell to turn this into exit 127).
            self._log("execute %s failed, %s" % (" ".join(cmd), e))
            return [-1] * len(self.columns)

        try:
            return self._parse_pqos_output(output)
        except (IndexError, ValueError) as e:
            self._log("execute %s failed, unexpected output %r: %s" % (" ".join(cmd), output, e))
            return [-1] * len(self.columns)
